
import torch
import torch.nn as nn

###############
##### RNN #####
###############

class _RNN(nn.Module):
    r'''
    Recurrent Neural Networks
    -------------------------
    Seq Encoding
    refer: 
    1. _RNNModule *****
    in darts-master\darts\models\forecasting\rnn_model.py
    2. RNNModel **
    in gluonts\model\deep_factor\RNNModel.py
    '''
    def __init__(self, name='GRU', input_size: int = 3, # target_size: int = 3, n_in: int = 8, n_out: int = 1, 
        hidden_layer_size=128, num_layers=1, bidirectional=False, dropout=0.
    ):
        super(_RNN, self).__init__() # 'GRU'，'LSTM', 100, 1, False
        self.input_size = input_size
        self.D = 2 if bidirectional else 1

        self.rnn = getattr(nn, name)(input_size=input_size, 
                            hidden_size=hidden_layer_size, 
                            num_layers=num_layers, 
                            bidirectional=bidirectional, 
                            batch_first=True,
                            dropout=dropout)#.cuda()
        self.decoder = nn.Linear(self.D*hidden_layer_size, input_size)#.cuda() # target_size=1, univariate
        # self.fc = nn.Linear(n_in, n_out)#.cuda()     (self.D*hidden_layer_size)
        ### refer TS2Vec
        # self.tsdecoder = TSDecoder(n_in=n_in, n_out=n_out, input_size=input_size, target_size=target_size)#, dec_mode='fusing2'


    def forward(self, input_seq):
        # batch_size = input_seq.shape[0] # len(input_seq)
        self.rnn.flatten_parameters()
        rnn_out, last_hidden_state = self.rnn(input_seq) #.reshape(batch_size, -1, self.input_size)
        predictions = self.decoder(rnn_out) #.view(batch_size, -1)
        
        return predictions # self.tsdecoder(predictions)

if __name__ == '__main__':
    # Smoke test: push one random batch through the encoder and
    # verify the output keeps the (batch, seq_len, features) shape.
    seq_len = 64       # length of the input window
    n_series = 41      # number of parallel time series / features
    batch = 50

    torch.set_num_threads(4)
    sample = torch.randn(batch, seq_len, n_series)

    net = _RNN(input_size=n_series)
    out = net(sample)
    print(out.shape)
